{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36ca04d4",
   "metadata": {},
   "source": [
    "# Loss Landscapes on CIFAR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5555ccd1",
   "metadata": {},
   "source": [
    "We visualize loss landscapes using [filter normalization](https://arxiv.org/abs/1712.09913) to show that ***self-attentions flatten loss landscapes*** as shown in Fig 1.a and Fig C.4.a."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0937b1c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Repository root: the working directory unless the repository is cloned below.\n",
    "root = \".\"\n",
    "\n",
    "# check whether run in Colab; if so, install the dependencies and fetch the repository\n",
    "if \"google.colab\" in sys.modules:\n",
    "    print(\"Running in Colab.\")\n",
    "    # %pip (rather than !pip3) installs into the environment of the running kernel.\n",
    "    %pip install matplotlib\n",
    "    %pip install einops==0.4.1\n",
    "    %pip install timm==0.5.4\n",
    "    %pip install pandas==1.3.2\n",
    "    root = \"./how-do-vits-work\"\n",
    "    # Clone only when the checkout is missing, so that re-running this cell is harmless.\n",
    "    if not os.path.isdir(root):\n",
    "        !git clone https://github.com/xxxnell/how-do-vits-work.git\n",
    "    # Put the checkout on the import path once, without stacking duplicate entries.\n",
    "    if root not in sys.path:\n",
    "        sys.path.append(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f199c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "import copy\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import models\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "import ops.loss_landscapes as lls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48de8d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the experiment configuration (uncomment exactly one line).\n",
    "# config_path = \"%s/configs/cifar10_vit.yaml\" % root\n",
    "config_path = \"%s/configs/cifar100_vit.yaml\" % root\n",
    "# config_path = \"%s/configs/imagenet_vit.yaml\" % root\n",
    "\n",
    "# safe_load builds only plain Python objects. Calling yaml.load without an explicit\n",
    "# Loader is deprecated since PyYAML 5.1 and raises a TypeError in PyYAML 6.\n",
    "with open(config_path) as f:\n",
    "    args = yaml.safe_load(f)\n",
    "\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50693515",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract every top-level config section as its own deep copy\n",
    "# (a section that is absent from the config becomes None).\n",
    "dataset_args, train_args, val_args, model_args, optim_args, env_args = (\n",
    "    copy.deepcopy(args.get(section))\n",
    "    for section in (\"dataset\", \"train\", \"val\", \"model\", \"optim\", \"env\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5113c601",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Create the train/test datasets from the config, passing download=True.\n",
    "dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "# Wrap both splits in data loaders; only the training split is shuffled.\n",
    "# From here on, `dataset_train` and `dataset_test` refer to the loaders.\n",
    "dataset_train = DataLoader(\n",
    "    dataset_train,\n",
    "    shuffle=True,\n",
    "    num_workers=train_args.get(\"num_workers\", 4),\n",
    "    batch_size=train_args.get(\"batch_size\", 128),\n",
    ")\n",
    "dataset_test = DataLoader(\n",
    "    dataset_test,\n",
    "    num_workers=val_args.get(\"num_workers\", 4),\n",
    "    batch_size=val_args.get(\"batch_size\", 128),\n",
    ")\n",
    "\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (\n",
    "    len(dataset_train.dataset), len(dataset_test.dataset), num_classes,\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4886d2e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from timm.data import Mixup\n",
    "\n",
    "def mixup_function(train_args):\n",
    "    \"\"\"Create the timm `Mixup` transform configured by `train_args`.\n",
    "\n",
    "    A deep copy of `train_args` is read: the \"mixup\" entry supplies the keyword\n",
    "    arguments for `Mixup`, and \"smoothing\" (0.0 when absent) is passed as\n",
    "    `label_smoothing`. Returns None when \"mixup\" is missing or None.\n",
    "    \"\"\"\n",
    "    config = copy.deepcopy(train_args)\n",
    "    mixup_kwargs = config.get(\"mixup\", None)\n",
    "    if mixup_kwargs is None:\n",
    "        return None\n",
    "    return Mixup(**mixup_kwargs, label_smoothing=config.get(\"smoothing\", 0.0))\n",
    "\n",
    "\n",
    "# Mixup transform built from the training config (None if it defines no mixup).\n",
    "transform = mixup_function(train_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7759909c",
   "metadata": {},
   "source": [
    "## Load Pretrained Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9bf668b",
   "metadata": {},
   "source": [
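    "The two cells below load a pretrained ResNet-50 and a pretrained ViT-Ti, respectively. Run only the cell for the model you want to analyze: both cells assign `model` and `uid`, so if both are run, the ViT-Ti cell (which runs last) determines what the following sections use."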
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b28fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# identifiers of the pretrained ResNet-50 for CIFAR-100\n",
    "name = \"resnet_50\"\n",
    "uid = \"691cc9a9e4\"\n",
    "\n",
    "# download the checkpoint\n",
    "url = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar\"\n",
    "path = \"checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar\"\n",
    "models.download(url=url, path=path)\n",
    "\n",
    "# timm does not provide a ResNet for CIFAR, so the model is built with models.get_model\n",
    "model = models.get_model(\n",
    "    name,\n",
    "    num_classes=num_classes,\n",
    "    stem=model_args.get(\"stem\", False),\n",
    ")\n",
    "\n",
    "# map the checkpoint to the GPU when one is available, otherwise to the CPU\n",
    "map_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "checkpoint = torch.load(path, map_location=map_location)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ea7108b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import timm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# identifier of the pretrained ViT-Ti for CIFAR-100\n",
    "uid = \"9857b21357\"\n",
    "\n",
    "# download the checkpoint\n",
    "url = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar\"\n",
    "path = \"checkpoints/vit_ti_cifar100_9857b21357.pth.tar\"\n",
    "models.download(url=url, path=path)\n",
    "\n",
    "# build the architecture with timm\n",
    "model = timm.models.vision_transformer.VisionTransformer(\n",
    "    # for CIFAR\n",
    "    num_classes=num_classes,\n",
    "    img_size=32,\n",
    "    patch_size=2,\n",
    "    # for ViT-Ti\n",
    "    embed_dim=192,\n",
    "    depth=12,\n",
    "    num_heads=3,\n",
    "    qkv_bias=False,\n",
    ")\n",
    "model.name = \"vit_ti\"  # used later to name the output directory and file\n",
    "models.stats(model)\n",
    "\n",
    "# map the checkpoint to the GPU when one is available, otherwise to the CPU\n",
    "map_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "checkpoint = torch.load(path, map_location=map_location)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e99d1e33",
   "metadata": {},
   "source": [
    "## Investigate the Loss Landscape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71108a9b",
   "metadata": {},
   "source": [
    "We measure not only the NLL but also the L2 term, using the *augmented training dataset*. The plotting section below combines the two into the loss, NLL + weight decay × L2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "192d134a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "scale = 1e-0  # half-width of the grid in both directions\n",
    "n = 21  # number of grid points per direction\n",
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "# Request the metrics on an n x n grid spanning [-scale, scale] along x and y.\n",
    "metrics_grid = lls.get_loss_landscape(\n",
    "    model, 1, dataset_train,\n",
    "    transform=transform,\n",
    "    kws=[\"pos_embed\", \"relative_position\"],\n",
    "    x_min=-1.0 * scale, x_max=1.0 * scale, n_x=n,\n",
    "    y_min=-1.0 * scale, y_max=1.0 * scale, n_y=n,\n",
    "    gpu=gpu,\n",
    ")\n",
    "\n",
    "# Output location: leaderboard/logs/<dataset>/<model>/<dataset>_<model>_<uid>_x<1/scale>_losslandscape.csv\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "csv_name = \"%s_%s_%s_x%s_losslandscape.csv\" % (dataset_name, model.name, uid, int(1 / scale))\n",
    "metrics_dir = os.path.join(leaderboard_path, csv_name)\n",
    "\n",
    "# One row per grid point: the grid key followed by the metrics recorded for it.\n",
    "metrics_list = [[*grid, *metrics] for grid, metrics in metrics_grid.items()]\n",
    "tests.save_metrics(metrics_dir, metrics_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d408ee71",
   "metadata": {},
   "source": [
    "## Plot the Loss Landscape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcb9e4d4",
   "metadata": {},
   "source": [
    "For the given loss landscape raw data, we visualize the loss (NLL + L2) landscape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690f92a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "\n",
    "\n",
    "# load losslandscape raw data of ResNet-50 or ViT-Ti\n",
    "names = [\"x\", \"y\", \"l1\", \"l2\", \"NLL\", \"Cutoff1\", \"Cutoff2\", \"Acc\", \"Acc-90\", \"Unc\", \"Unc-90\", \"IoU\", \"IoU-90\", \"Freq\", \"Freq-90\", \"Top-5\", \"Brier\", \"ECE\", \"ECSE\"]\n",
    "# path = \"%s/resources/results/cifar100_resnet_dnn_50_losslandscape.csv\" % root  # for ResNet-50\n",
    "path = \"%s/resources/results/cifar100_vit_ti_losslandscape.csv\" % root  # for ViT-Ti\n",
    "data = pd.read_csv(path, names=names)\n",
    "data[\"loss\"] = data[\"NLL\"] + optim_args[\"weight_decay\"] * data[\"l2\"]  # NLL + l2\n",
    "\n",
    "# prepare data: the rows are reshaped into a square p x p grid\n",
    "p = int(math.sqrt(len(data)))\n",
    "shape = [p, p]\n",
    "xs = data[\"x\"].to_numpy().reshape(shape) \n",
    "ys = data[\"y\"].to_numpy().reshape(shape)\n",
    "zs = data[\"loss\"].to_numpy().reshape(shape)\n",
    "\n",
    "# shift the surface so that its lowest finite value is 0, then mask everything above 42\n",
    "zs = zs - zs[np.isfinite(zs)].min()\n",
    "zs[zs > 42] = np.nan\n",
    "\n",
    "# color every grid point by its (finite) height\n",
    "norm = plt.Normalize(zs[np.isfinite(zs)].min(), zs[np.isfinite(zs)].max())  # normalize to [0,1]\n",
    "colors = cm.plasma(norm(zs))\n",
    "rcount, ccount, _ = colors.shape\n",
    "\n",
    "fig = plt.figure(figsize=(4.2, 4), dpi=120)\n",
    "# Figure.gca() no longer accepts keyword arguments (deprecated in Matplotlib 3.4,\n",
    "# removed in 3.6), so the 3D axes are created with add_subplot instead.\n",
    "ax = fig.add_subplot(111, projection=\"3d\")\n",
    "ax.view_init(elev=15, azim=15)  # angle\n",
    "\n",
    "# make the panes transparent\n",
    "ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "# make the grid lines transparent\n",
    "ax.xaxis._axinfo[\"grid\"]['color'] =  (1,1,1,0)\n",
    "ax.yaxis._axinfo[\"grid\"]['color'] =  (1,1,1,0)\n",
    "ax.zaxis._axinfo[\"grid\"]['color'] =  (1,1,1,0)\n",
    "\n",
    "# rcount/ccount equal the grid size, so the surface is drawn without downsampling\n",
    "surf = ax.plot_surface(\n",
    "    xs, ys, zs, \n",
    "    rcount=rcount, ccount=ccount,\n",
    "    facecolors=colors, shade=False,\n",
    ")\n",
    "surf.set_facecolor((0,0,0,0))\n",
    "\n",
    "# remove white spaces\n",
    "adjust_lim = 0.8\n",
    "ax.set_xlim(-1 * adjust_lim, 1 * adjust_lim)\n",
    "ax.set_ylim(-1 * adjust_lim, 1 * adjust_lim)\n",
    "ax.set_zlim(10, 32)\n",
    "fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15c3093f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
